/******************************************************//**
 * C++ library of the Stochastic Gradient Descent (SGD) methods.
 *
 * Copyright (c) 2020-2031 Yi Zhang (zhangyiss@icloud.com)
 * All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *********************************************************/

#ifndef _SGD_H
#define _SGD_H

/*
 * System headers are included outside the extern "C" block: they manage
 * their own language linkage and must not be forced into C linkage.
 */
#include <stddef.h>

/*
 * Give the public API C linkage when this header is compiled as C++, so the
 * same library can be used from both languages. The standard predefined macro
 * is __cplusplus (two leading underscores). extern "C" is not valid C syntax,
 * so it must only be emitted for C++ translation units.
 */
#ifdef __cplusplus
extern "C"
{
#endif

/**
 * @brief    A simple definition of the float type we use here. 
 * Easy to change in the future. For now it is just an alias of the double type.
 */
typedef double sgd_float;

/**
 * @brief      Types of method that could be recognized by the sgd_solver() function.
 */
typedef enum
{
	/**
	 * Classic momentum.
	 */
	SGD_MOMENTUM,

	/**
	 * Nesterov's accelerated gradient (NAG).
	 */
	SGD_NAG,

	/**
	 * AdaGrad method.
	 */
	SGD_ADAGRAD,

	/**
	 * RMSProp method.
	 */
	SGD_RMSPROP,

	/**
	 * Adam method.
	 */
	SGD_ADAM,

	/**
	 * Nadam method.
	 */
	SGD_NADAM,

	/**
	 * AdaMax method.
	 */
	SGD_ADAMAX,

	/**
	 * AdaBelief method.
	 */
	SGD_ADABELIEF,
} sgd_solver_enum;

/**
 * @brief    Parameters of the SGD methods.
 */
typedef struct
{
	/**
	 * Iteration times for the entire observation set. The default is 300.
	 */
	int iteration;

	/**
	 * Epsilon for convergence test. This parameter determines the accuracy 
	 * with which the solution is to be found. Must be bigger than zero and 
	 * the default is 1e-6.
	 */
	sgd_float epsilon;

	/**
	 * Damping rate of the classic momentum method and the NAG method, which 
	 * is typically given between 0 and 1. The default is 0.01.
	 */
	sgd_float mu;

	/**
	 * Step size of the iteration. The default value is 0.001 for Adam and 0.002
	 * for AdaMax.
	 */
	sgd_float alpha;

	/**
	 * Exponential decay rate for the first-order moment estimates. The range of
	 * this parameter is [0, 1) and the default value is 0.9.
	 */
	sgd_float beta_1;

	/**
	 * Exponential decay rate for the second-order moment estimates. The range of
	 * this parameter is [0, 1) and the default value is 0.999.
	 */
	sgd_float beta_2;

	/**
	 * A small positive number that keeps the algorithm numerically valid
	 * (presumably the denominator guard of the adaptive methods -- confirm in
	 * the implementation). The default value is 1e-8.
	 */
	sgd_float sigma;
} sgd_para;

/**
 * @brief    Callback interface for calculating the value of objective function
 * and the corresponding model gradients.
 * 
 * @param    instance   The user data sent for the sgd_solver() functions by the client.
 * @param    x          Pointer of the current solution (n_size elements, read-only).
 * @param    g          Pointer of the model gradient (n_size elements); presumably
 *                      filled in by the callback.
 * @param    n_size     Length of the solution.
 * @param    m          Index of the observation.
 * 
 * @return   Value of objective function.
 */
typedef sgd_float (*sgd_evaulate_ptr)(void *instance, const sgd_float *x, sgd_float *g, 
	const int n_size, const int m);

/**
 * @brief    Correctly spelled alias of sgd_evaulate_ptr. The original
 * (misspelled) name is kept for backward compatibility; both denote the same type.
 */
typedef sgd_evaulate_ptr sgd_evaluate_ptr;

/**
 * @brief    Callback interface for monitoring the progress and terminating the
 * iteration if necessary.
 * 
 * @param    instance   The user data sent for the sgd_solver() functions by the client.
 * @param    fx         Current value of the objective function.
 * @param    x          Current solution.
 * @param    g          Current model gradients.
 * @param    param      User defined iteration parameters.
 * @param    n_size     Length of the solution array.
 * @param    k          Times of the iteration.
 * 
 * @return   int        Zero to continue the optimization process. Otherwise, the optimization 
 * process will be terminated.
 */
typedef int (*sgd_progress_ptr)(void *instance, sgd_float fx, const sgd_float *x, const sgd_float *g, 
	const sgd_para *param, const int n_size, const int k);

/**
 * @brief      Allocate memory for a sgd_float array.
 *
 * @param[in]  n_size  Number of sgd_float elements in the array.
 *
 * @return     Pointer of the data. Release it with sgd_free().
 */
sgd_float *sgd_malloc(const int n_size);

/**
 * @brief      Destroy memory used by the sgd_float type array.
 *
 * @param      x     Pointer of the array (as returned by sgd_malloc()).
 */
void sgd_free(sgd_float *x);

/**
 * @brief      Return a sgd_para type instance with default values.
 *
 * @note       Declared with (void): in C an empty parameter list would mean
 *             "unspecified arguments" and disable argument checking.
 *
 * @return     A sgd_para type instance.
 */
sgd_para sgd_default_parameters(void);

/**
 * @brief      Return a string explanation for the sgd_solver() function's return values.
 *
 * @param[in]  er_index  The error index returned by the sgd_solver() function.
 *
 * @return     A string explanation of the error.
 */
const char* sgd_error_str(int er_index);

/**
 * @brief      Minimize an objective with one of the SGD methods listed in
 *             sgd_solver_enum.
 * 
 * @note       The size of all arrays must be equal to n_size.
 * @note       The default value of solver_id is a C++ default argument. C
 *             callers must always pass solver_id explicitly.
 *
 * @param[in]  Evafp       Callback function for calculating the objective function and its gradient.
 * @param[in]  Profp       Callback function for monitoring the optimization process.
 * @param[out] fx          Returned best value of the objective function by now.
 * @param      m           Pointer of the solution array (presumably the initial
 *                         guess on entry and the solution on return -- confirm).
 * @param[in]  n_size      Length of the solution array.
 * @param[in]  m_size      Number of observations.
 * @param[in]  param       Parameters of optimization process.
 * @param      instance    The user data sent for the function by the client.
 * @param[in]  solver_id   Solver type used to solve the objective. The default value
 *                         (C++ only) is SGD_ADAM.
 *
 * @return     Status of the function; pass it to sgd_error_str() for a description.
 */
int sgd_solver(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *m, 
	const int n_size, const int m_size, const sgd_para *param, void *instance, 
	sgd_solver_enum solver_id
#ifdef __cplusplus
	= SGD_ADAM /* default arguments are a C++-only feature */
#endif
	);

#ifdef __cplusplus
}
#endif

#endif // _SGD_H